############################################
# author: H. Wang (hwang@cfa.harvard.edu)
# purpose:  
#  simple functions for use with Version 2 MARCI MDGM &
#  MDSSD downloaded from the Harvard Dataverse
# 
# note:
#   mdssdfnm and mdgmfnm should include the full path
#
# history:
#  initial release April 2023
#
############################################
# read MDSSD results and plot them over the MDGM
############################################
#
import scipy.io as sio
import numpy as np
import os
#------
def read_mdssd(mdssdfnm):
    """Load the 'results' structure from an MDSSD IDL save file.

    Parameters
    ----------
    mdssdfnm : str
        Path to the MDSSD save file (should include the full path).

    Returns
    -------
    tuple
        ``(record, field_names)``: the first element of the saved
        ``results`` array, and the field names of that array's dtype.
    """
    saved_vars = sio.readsav(mdssdfnm)
    result_records = saved_vars['results']
    first_record = result_records[0]
    return first_record, result_records.dtype.names

#------  
def read_mdgm(mdgmfnm):
    """Read an MDGM image and return it as an integer array on a 0-255 scale.

    Parameters
    ----------
    mdgmfnm : str
        Path to the MDGM image file (e.g. PNG or JPG); should include the
        full path.

    Returns
    -------
    numpy.ndarray
        Integer pixel array with values on a 0-255 scale, with its row
        order reversed relative to the file. Callers display it with
        ``origin='lower'`` so that the north pole ends up at the top.
    """
    from matplotlib import image

    # this will read south pole at top
    img = image.imread(mdgmfnm)

    # matplotlib returns PNG images as floats in [0, 1] and other formats
    # (e.g. JPG) as integers in [0, 255]. Decide by dtype rather than by a
    # case-sensitive extension check, so '.PNG' files are scaled too.
    if np.issubdtype(img.dtype, np.floating):
        # round before the integer cast: k/255*255 can land just below k
        # in floating point, and truncation would then yield k-1
        img = np.rint(img * 255.)

    # flip to make north pole at top (when displayed with origin='lower')
    array = img.astype(int)
    imageout = np.flip(array, axis=0)

    return imageout
#------ 
def display_mdgm(fnm, xsize=8, ysize=6):
    """Show an MDGM image in a new matplotlib figure titled with its filename.

    Parameters
    ----------
    fnm : str
        Path to the MDGM image file (should include the full path).
    xsize, ysize : float, optional
        Figure width and height, passed to ``plt.figure(figsize=...)``.
    """
    from matplotlib import pyplot as plt

    array = read_mdgm(fnm)
    plt.figure(figsize=(xsize, ysize))
    # read_mdgm reverses the row order; origin='lower' undoes that on screen
    plt.imshow(array, origin='lower')
    plt.title(fnm)
    plt.axis('off')
    # non-blocking, so the caller regains control immediately
    plt.show(block=False)

#------   
def highlight_mdssd(mdgmfnm, mdssdfnm, xsize=8, ysize=6):
    """Display an MDGM with the ROIs from its paired MDSSD file highlighted.

    The green channel of every ROI pixel is halved, so objects appear
    purple-tinted in the RGB image. Each object is labelled with its storm
    ID at the (integer-truncated) median of its ROI pixel coordinates.

    Parameters
    ----------
    mdgmfnm : str
        Path to the MDGM image; must pair up with ``mdssdfnm``.
    mdssdfnm : str
        Path to the MDSSD save file whose ROIs are highlighted on the MDGM.
    xsize, ysize : float, optional
        Figure width and height, passed to ``plt.figure(figsize=...)``.
    """
    from matplotlib import pyplot as plt

    # read MDGM as RGB image
    image = read_mdgm(mdgmfnm)
    # basic slicing returns a view, so edits to ``green`` write straight
    # into ``image`` -- no need to copy the channel back afterwards
    green = image[:, :, 1]

    # read MDSSD results
    output, names = read_mdssd(mdssdfnm)

    # annotation positions go in mxmy, annotation text in labels
    nobj = len(names)
    mxmy = np.zeros((nobj, 2), dtype=int)
    labels = []
    for k in range(nobj):
        thisobj = output[k]
        # atleast_1d guards against a single-pixel ROI stored as a scalar
        roix = np.atleast_1d(thisobj.roix[0])
        roiy = np.atleast_1d(thisobj.roiy[0])
        mxmy[k, :] = np.array([np.median(roix), np.median(roiy)]).astype(int)
        labels.append(str(thisobj.storm_id[0], 'utf-8'))
        # Halve the green value of every ROI pixel (rows = roiy, cols = roix).
        # ufunc.at is unbuffered, so a repeated pixel is halved once per
        # occurrence, exactly like a per-pixel loop; integer floor division
        # matches truncating x/2.0 for the non-negative 0-255 pixel values.
        np.floor_divide.at(green, (roiy, roix), 2)

    # matplotlib display
    plt.figure(figsize=(xsize, ysize))
    plt.imshow(image, origin='lower')
    for label, xy in zip(labels, mxmy):
        plt.annotate(label, xy)
    plt.title(mdgmfnm)
    plt.axis('off')
    plt.show(block=False)

#------   
